In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go
In [2]:
# Load MNIST and cast: float32 pixels for the network, int32 labels
# (labels appear to be one-hot, judging by the argmax usage further down).
train_images, train_labels, test_images, test_labels = mnist()

train_images, test_images = (
    split.astype(jnp.float32) for split in (train_images, test_images)
)
train_labels, test_labels = (
    jnp.asarray(split, dtype=jnp.int32) for split in (train_labels, test_labels)
)
In [34]:
def visualize_images(images_tensor, w=28, h=28, col_wrap=5):
    """Display a batch of flattened images as a faceted image grid.

    Parameters
    ----------
    images_tensor : array of pixel values, reshaped to (-1, w, h) so that
        each leading-axis slice becomes one facet. NOTE(review): for
        non-square images the (w, h) reshape order should be confirmed
        against how the data was flattened.
    w, h : side lengths in pixels (28x28 for MNIST).
    col_wrap : number of facet columns per row.
    """
    img = images_tensor.reshape(-1, w, h)

    fig = px.imshow(img, binary_string=False, facet_col=0, facet_col_wrap=col_wrap)

    # Blank out the auto-generated "facet_col=<i>" label above each image.
    # (The original built an {index: ""} lookup table whose values were all
    # empty strings and had an unused loop variable — setting text="" directly
    # produces the identical rendered result.)
    fig.for_each_annotation(lambda a: a.update(text=""))

    fig.show()
In [4]:
# Small-scale random init for a 784 -> 256 -> 256 -> 256 -> 10 MLP.
# Seeded via default_rng so training runs are reproducible — the original
# used the unseeded global np.random state, making every run different.
_init_rng = np.random.default_rng(0)

net_parameters = {
    'w0': _init_rng.standard_normal((256, 784)) * 0.1,
    'w1': _init_rng.standard_normal((256, 256)) * 0.1,
    'w2': _init_rng.standard_normal((256, 256)) * 0.1,
    'w3': _init_rng.standard_normal((10, 256)) * 0.1,
}
In [6]:
def ReLU(x):
    """Element-wise rectified linear unit: max(x, 0)."""
    return jnp.maximum(x, 0)

def forward(parameters, x):
    """Run a batch of inputs through the 4-layer MLP.

    parameters: dict of weight matrices 'w0'..'w3' (no biases).
    x: batch of row vectors, shape (batch, 784).
    Returns raw logits, shape (batch, 10).
    """
    # Work in column-vector orientation so each layer is a plain matmul.
    activations = x.T
    for layer in ('w0', 'w1', 'w2'):
        activations = ReLU(parameters[layer] @ activations)
    logits = parameters['w3'] @ activations  # final layer: no activation
    return logits.T
In [7]:
def loss(parameters, x, y):
    """Mean cross-entropy between one-hot targets y and the network's predictions.

    parameters: dict of weight matrices for forward().
    x: batch of inputs, shape (batch, 784).
    y: one-hot targets, shape (batch, 10).
    """
    logits = forward(parameters, x)
    # log_softmax is numerically stabler than log(softmax(...)): it avoids
    # log(0) -> -inf when a softmax probability underflows to zero.
    log_probs = jax.nn.log_softmax(logits)
    return -(y * log_probs).sum(axis=-1).mean()

loss(net_parameters, test_images, test_labels)
Out[7]:
Array(2.8375754, dtype=float32)
In [8]:
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
Out[8]:
Array(0.10371667, dtype=float32)
In [9]:
# jit-compile the gradient: it is called 100 times on identically-shaped
# full-batch inputs, so compilation cost is paid once.
grad_loss = jax.jit(jax.grad(loss))
lr = 0.1  # learning rate for plain full-batch gradient descent

# keep track of all the previous gradients (analyzed below for norms/angles)
grad_history = []

for epoch in range(100):

    p_grad = grad_loss(net_parameters, train_images, train_labels)
    grad_history.append(p_grad)

    # One loop over the layers instead of a hand-written update line per
    # weight matrix — stays correct if layers are added or renamed.
    for key in net_parameters:
        net_parameters[key] -= lr * p_grad[key]

    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")
epoch 0
validation loss: 2.9189021587371826
train loss: 2.8750860691070557
accuracy: 0.164000004529953


epoch 1
validation loss: 2.4470043182373047
train loss: 2.4396541118621826
accuracy: 0.2874833345413208


epoch 2
validation loss: 1.8561428785324097
train loss: 1.8763872385025024
accuracy: 0.3936833441257477


epoch 3
validation loss: 1.551874041557312
train loss: 1.5782337188720703
accuracy: 0.5454500317573547


epoch 4
validation loss: 1.3967901468276978
train loss: 1.4261633157730103
accuracy: 0.6118500232696533


epoch 5
validation loss: 1.271924376487732
train loss: 1.3043782711029053
accuracy: 0.6500666737556458


epoch 6
validation loss: 1.1636574268341064
train loss: 1.1976664066314697
accuracy: 0.6835333704948425


epoch 7
validation loss: 1.069806456565857
train loss: 1.1046808958053589
accuracy: 0.708383321762085


epoch 8
validation loss: 0.9889728426933289
train loss: 1.024280309677124
accuracy: 0.7311000227928162


epoch 9
validation loss: 0.919693112373352
train loss: 0.9550483226776123
accuracy: 0.7476666569709778


epoch 10
validation loss: 0.860313355922699
train loss: 0.895563542842865
accuracy: 0.7634833455085754


epoch 11
validation loss: 0.8095411658287048
train loss: 0.8444904685020447
accuracy: 0.7739666700363159


epoch 12
validation loss: 0.7667122483253479
train loss: 0.8009517788887024
accuracy: 0.786133348941803


epoch 13
validation loss: 0.7311145067214966
train loss: 0.7651796340942383
accuracy: 0.7906333208084106


epoch 14
validation loss: 0.7095828056335449
train loss: 0.7416921854019165
accuracy: 0.7956833243370056


epoch 15
validation loss: 0.718275249004364
train loss: 0.7508598566055298
accuracy: 0.7748500108718872


epoch 16
validation loss: 0.8224052786827087
train loss: 0.8473957180976868
accuracy: 0.7249833345413208


epoch 17
validation loss: 1.176560878753662
train loss: 1.199607253074646
accuracy: 0.6542666554450989


epoch 18
validation loss: 1.0818630456924438
train loss: 1.0991448163986206
accuracy: 0.6475833654403687


epoch 19
validation loss: 0.9507499933242798
train loss: 0.9748549461364746
accuracy: 0.6638333201408386


epoch 20
validation loss: 0.7726666927337646
train loss: 0.799832820892334
accuracy: 0.7476000189781189


epoch 21
validation loss: 0.6838322281837463
train loss: 0.7125795483589172
accuracy: 0.7737833261489868


epoch 22
validation loss: 0.6092954874038696
train loss: 0.6370028853416443
accuracy: 0.8159833550453186


epoch 23
validation loss: 0.5797497034072876
train loss: 0.6072186827659607
accuracy: 0.8192499876022339


epoch 24
validation loss: 0.5597652196884155
train loss: 0.5865647196769714
accuracy: 0.8320333361625671


epoch 25
validation loss: 0.5450906753540039
train loss: 0.5715097784996033
accuracy: 0.8308166861534119


epoch 26
validation loss: 0.5375725626945496
train loss: 0.5629200339317322
accuracy: 0.8347166776657104


epoch 27
validation loss: 0.5302454829216003
train loss: 0.5557253360748291
accuracy: 0.8318166732788086


epoch 28
validation loss: 0.5343443155288696
train loss: 0.5580121278762817
accuracy: 0.8284167051315308


epoch 29
validation loss: 0.5323984026908875
train loss: 0.5570401549339294
accuracy: 0.8246833682060242


epoch 30
validation loss: 0.5490544438362122
train loss: 0.5707834959030151
accuracy: 0.8147667050361633


epoch 31
validation loss: 0.5427020788192749
train loss: 0.5665341019630432
accuracy: 0.8155333399772644


epoch 32
validation loss: 0.559158980846405
train loss: 0.5792463421821594
accuracy: 0.8072666525840759


epoch 33
validation loss: 0.5375383496284485
train loss: 0.560635507106781
accuracy: 0.8168333172798157


epoch 34
validation loss: 0.537998616695404
train loss: 0.5572067499160767
accuracy: 0.8160666823387146


epoch 35
validation loss: 0.5104748606681824
train loss: 0.533051609992981
accuracy: 0.8287833333015442


epoch 36
validation loss: 0.5004280209541321
train loss: 0.518977165222168
accuracy: 0.8355666995048523


epoch 37
validation loss: 0.47629404067993164
train loss: 0.4982987940311432
accuracy: 0.8442500233650208


epoch 38
validation loss: 0.46580857038497925
train loss: 0.4837290644645691
accuracy: 0.8530333638191223


epoch 39
validation loss: 0.44755056500434875
train loss: 0.46879786252975464
accuracy: 0.857283353805542


epoch 40
validation loss: 0.43991610407829285
train loss: 0.4572533965110779
accuracy: 0.8648000359535217


epoch 41
validation loss: 0.4263867139816284
train loss: 0.44678544998168945
accuracy: 0.8656499981880188


epoch 42
validation loss: 0.4212598502635956
train loss: 0.4380277991294861
accuracy: 0.8720333576202393


epoch 43
validation loss: 0.41086313128471375
train loss: 0.4304296374320984
accuracy: 0.8713499903678894


epoch 44
validation loss: 0.4074323773384094
train loss: 0.42363497614860535
accuracy: 0.8773166537284851


epoch 45
validation loss: 0.3989733159542084
train loss: 0.41777488589286804
accuracy: 0.8760833144187927


epoch 46
validation loss: 0.3965945243835449
train loss: 0.4122386574745178
accuracy: 0.8813999891281128


epoch 47
validation loss: 0.38940292596817017
train loss: 0.40750280022621155
accuracy: 0.8791666626930237


epoch 48
validation loss: 0.38778334856033325
train loss: 0.40287551283836365
accuracy: 0.8844833374023438


epoch 49
validation loss: 0.3814713656902313
train loss: 0.39894166588783264
accuracy: 0.881933331489563


epoch 50
validation loss: 0.3804478049278259
train loss: 0.39499184489250183
accuracy: 0.8864499926567078


epoch 51
validation loss: 0.37479153275489807
train loss: 0.3916802704334259
accuracy: 0.8835833668708801


epoch 52
validation loss: 0.37427714467048645
train loss: 0.3882569670677185
accuracy: 0.8885666728019714


epoch 53
validation loss: 0.3691239655017853
train loss: 0.38548731803894043
accuracy: 0.8851000070571899


epoch 54
validation loss: 0.3691091239452362
train loss: 0.38252490758895874
accuracy: 0.8900166749954224


epoch 55
validation loss: 0.3643854260444641
train loss: 0.38027825951576233
accuracy: 0.8863833546638489


epoch 56
validation loss: 0.36489102244377136
train loss: 0.3777557909488678
accuracy: 0.8910499811172485


epoch 57
validation loss: 0.36056894063949585
train loss: 0.37602779269218445
accuracy: 0.887416660785675


epoch 58
validation loss: 0.36157673597335815
train loss: 0.3739010989665985
accuracy: 0.8919166922569275


epoch 59
validation loss: 0.35761409997940063
train loss: 0.3726806044578552
accuracy: 0.8880333304405212


epoch 60
validation loss: 0.3591335713863373
train loss: 0.37092408537864685
accuracy: 0.8924833536148071


epoch 61
validation loss: 0.35549429059028625
train loss: 0.3702082335948944
accuracy: 0.8884333372116089


epoch 62
validation loss: 0.35747775435447693
train loss: 0.36875709891319275
accuracy: 0.8926166892051697


epoch 63
validation loss: 0.3540874123573303
train loss: 0.36846068501472473
accuracy: 0.8882166743278503


epoch 64
validation loss: 0.356381356716156
train loss: 0.3671784996986389
accuracy: 0.8929499983787537


epoch 65
validation loss: 0.3530734181404114
train loss: 0.3671029210090637
accuracy: 0.8879666924476624


epoch 66
validation loss: 0.35538172721862793
train loss: 0.36573588848114014
accuracy: 0.8930000066757202


epoch 67
validation loss: 0.3518384099006653
train loss: 0.36550372838974
accuracy: 0.8879833221435547


epoch 68
validation loss: 0.3537236750125885
train loss: 0.36370065808296204
accuracy: 0.8935333490371704


epoch 69
validation loss: 0.3496054708957672
train loss: 0.36286935210227966
accuracy: 0.8889333605766296


epoch 70
validation loss: 0.3506583273410797
train loss: 0.36033743619918823
accuracy: 0.8945333361625671


epoch 71
validation loss: 0.3457734286785126
train loss: 0.35858094692230225
accuracy: 0.8905333280563354


epoch 72
validation loss: 0.3457915782928467
train loss: 0.3552532494068146
accuracy: 0.8962500095367432


epoch 73
validation loss: 0.3401920199394226
train loss: 0.3524875044822693
accuracy: 0.8927666544914246


epoch 74
validation loss: 0.33920755982398987
train loss: 0.3485132157802582
accuracy: 0.8985666632652283


epoch 75
validation loss: 0.33331069350242615
train loss: 0.345060259103775
accuracy: 0.8959000110626221


epoch 76
validation loss: 0.33171555399894714
train loss: 0.34090983867645264
accuracy: 0.9013167023658752


epoch 77
validation loss: 0.3260519504547119
train loss: 0.3372683525085449
accuracy: 0.8986999988555908


epoch 78
validation loss: 0.32424861192703247
train loss: 0.33335527777671814
accuracy: 0.9036499857902527


epoch 79
validation loss: 0.31921058893203735
train loss: 0.32992836833000183
accuracy: 0.9015833139419556


epoch 80
validation loss: 0.31746235489845276
train loss: 0.3264819085597992
accuracy: 0.9054333567619324


epoch 81
validation loss: 0.31322750449180603
train loss: 0.323489248752594
accuracy: 0.9037500023841858


epoch 82
validation loss: 0.31165578961372375
train loss: 0.3205629289150238
accuracy: 0.9069333672523499


epoch 83
validation loss: 0.3081541359424591
train loss: 0.3179991543292999
accuracy: 0.906000018119812


epoch 84
validation loss: 0.3067646920681
train loss: 0.3155311942100525
accuracy: 0.9082000255584717


epoch 85
validation loss: 0.30387166142463684
train loss: 0.3133417069911957
accuracy: 0.907633364200592


epoch 86
validation loss: 0.30262669920921326
train loss: 0.3112283945083618
accuracy: 0.909250020980835


epoch 87
validation loss: 0.3001970052719116
train loss: 0.3093218207359314
accuracy: 0.9088833332061768


epoch 88
validation loss: 0.2990494668483734
train loss: 0.30746665596961975
accuracy: 0.9101166725158691


epoch 89
validation loss: 0.2969615161418915
train loss: 0.30576759576797485
accuracy: 0.9101666808128357


epoch 90
validation loss: 0.29586556553840637
train loss: 0.30409154295921326
accuracy: 0.9107666611671448


epoch 91
validation loss: 0.29403406381607056
train loss: 0.3025375008583069
accuracy: 0.9113500118255615


epoch 92
validation loss: 0.29297715425491333
train loss: 0.3010016977787018
accuracy: 0.9117500185966492


epoch 93
validation loss: 0.2913386821746826
train loss: 0.2995525300502777
accuracy: 0.9119666814804077


epoch 94
validation loss: 0.29030153155326843
train loss: 0.29811716079711914
accuracy: 0.9125833511352539


epoch 95
validation loss: 0.28880614042282104
train loss: 0.29674792289733887
accuracy: 0.9126999974250793


epoch 96
validation loss: 0.2877866327762604
train loss: 0.29539209604263306
accuracy: 0.9132333397865295


epoch 97
validation loss: 0.28640565276145935
train loss: 0.294091135263443
accuracy: 0.9133833646774292


epoch 98
validation loss: 0.2854039669036865
train loss: 0.292802095413208
accuracy: 0.9140333533287048


epoch 99
validation loss: 0.28412362933158875
train loss: 0.2915572226047516
accuracy: 0.9142833352088928


In [10]:
# Sanity-check one test example: render the digit, then show the raw
# logits the trained network assigns to it (largest logit = predicted class).
im = 0
visualize_images(test_images[im])
forward(net_parameters, test_images[im])
Out[10]:
Array([-0.6755133 , -4.4337187 , -0.31666017,  1.5255424 , -1.6595467 ,
       -0.63005555, -2.952984  ,  8.591555  , -0.7318695 ,  1.9986991 ],      dtype=float32)
In [11]:
# the magnitude of the gradient at each training step

# L2 norm of each layer's gradient per step — one comprehension instead of
# four copy-pasted append lines per weight matrix.
grad_norms = {
    key: [np.linalg.norm(grad_vector[key].flatten()) for grad_vector in grad_history]
    for key in ('w0', 'w1', 'w2', 'w3')
}

fig = px.line(grad_norms)
fig.show()
In [12]:
# for each training step, calculate the angle between the current and previous gradient

keys = ['w0', 'w1', 'w2', 'w3']
grad_cosines = {key: [] for key in keys}
grad_angles = {key: [] for key in keys}

for i in range(1, len(grad_history)):
    for key in keys:
        g_cur = grad_history[i][key].flatten()
        g_prev = grad_history[i - 1][key].flatten()
        # cosine similarity between consecutive gradient vectors
        cos = (g_cur / np.linalg.norm(g_cur)) @ (g_prev / np.linalg.norm(g_prev))
        grad_cosines[key].append(cos)
        # clip to [-1, 1]: floating-point rounding can push |cos| slightly
        # past 1, in which case arccos returns NaN and the plot drops points.
        angle = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
        grad_angles[key].append(angle)

fig_0 = px.line(grad_cosines, title="Cosine Between Each Previous Gradient")
fig_0.show()

fig_1 = px.line(grad_angles, title="Angle Between Each Previous Gradient")
fig_1.show()
In [45]:
# Here we are going to find the similarity between each gradient, and each other gradient (per weight)
for key in ['w0', 'w1', 'w2', 'w3']:
    # The history of every gradient for this parameter during the training process,
    # as one array of shape (training_epochs, output_dim, input_dim).
    history = np.array([gradient_dict[key] for gradient_dict in grad_history])

    training_epochs, output_dim, input_dim = history.shape
    history = history.reshape(training_epochs, output_dim * input_dim)

    # normalize the gradient vector for each time step
    # (keepdims lets broadcasting divide row-wise without the double transpose)
    magnitudes = np.linalg.norm(history, axis=-1, keepdims=True)
    history = history / magnitudes

    # cosine of the angle between the gradient of each step and every other step;
    # the product is already (training_epochs, training_epochs), so no
    # hard-coded reshape(100, 100) — that silently broke if the epoch count changed.
    similarity_matrix = history @ history.T
    fig = px.imshow(similarity_matrix, title=f"Similarity Matrix for {key} Gradients")
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=100,
            pad=4
        ),
    )
    fig.show()